import torch
import torch.nn as nn
import torch.optim as optim
import math
import torch.nn.functional as F
import numpy as np
from generate_T import Generator_matrix
# Default compute device for this module: the first CUDA GPU when one is
# visible to PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

class EACR(nn.Module):
    """Two-branch CNN over token self-similarity maps, fused with a pooled residual head.

    The forward pass embeds an integer sequence, builds the Gram matrix
    ``A = E @ E^T`` of the embeddings and a second map ``AA`` obtained by
    multiplying ``A`` by the fixed matrix ``T`` from both sides, runs each map
    through its own stack of four conv + max-pool stages, and classifies the
    concatenated features into 10 logits.  A parallel residual head classifies
    the per-position mean of the embeddings; the two logit vectors are blended
    with a learnable sigmoid gate.

    The layer sizes are hard-wired for sequences of length 784: the residual
    head is ``Linear(784, 26)`` and the final ``Linear(288, 10)`` equals
    2 branches x 16 channels x 3 x 3, which is what the conv stacks produce
    for 784 x 784 inputs.

    Args:
        input_dim1: Number of rows of the embedding table (vocabulary size).
        hidden_dim: Width of each embedding vector.
        dropout_prob: Dropout probability inside the residual head.
        image_size, patch_size, dim, num_heads, num_layers, input_dim2,
        embed_dim: Accepted for interface compatibility; not referenced by
            this class.
    """

    def __init__(self, image_size, patch_size, dim, num_heads, num_layers, input_dim1, input_dim2, hidden_dim, dropout_prob, embed_dim):
        super().__init__()
        size = 784
        # Generator_matrix is project code; it is expected to return a
        # (size, size) array-like here (28 * 28 == 784) -- TODO confirm.
        T_numpy = Generator_matrix(size, 28, 28)
        # Register T as a buffer so it follows the module through .to(),
        # .cuda(), .half() and DataParallel replication instead of being pinned
        # to a device chosen at import time.  persistent=False keeps it out of
        # state_dict, so checkpoints saved before this change still load.
        self.register_buffer(
            "T", torch.tensor(T_numpy, dtype=torch.float32), persistent=False
        )
        # NOTE(review): this writes the matrix to the current working directory
        # on every construction.  Kept for backward compatibility; consider
        # moving it to a debug utility.
        np.savetxt('T_numpy.txt', T_numpy, fmt='%d')

        # Sub-modules are created in the same order as before so that seeded
        # initialisation and parameter ordering stay unchanged.
        self.embedding = nn.Embedding(input_dim1, hidden_dim)
        self.activate = nn.Tanh()
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)

        # Branch 1: processes the raw similarity map A.
        self.pool_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1_1 = nn.Conv2d(1, 4, kernel_size=4, stride=2, padding=2)
        self.conv2_1 = nn.Conv2d(4, 8, kernel_size=4, stride=2, padding=2)
        self.conv3_1 = nn.Conv2d(8, 16, kernel_size=4, stride=2, padding=2)
        self.conv4_1 = nn.Conv2d(16, 16, kernel_size=4, stride=2, padding=2)

        # Branch 2: processes the T-transformed similarity map AA.
        self.pool_2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1_2 = nn.Conv2d(1, 4, kernel_size=4, stride=2, padding=2)
        self.conv2_2 = nn.Conv2d(4, 8, kernel_size=4, stride=2, padding=2)
        self.conv3_2 = nn.Conv2d(8, 16, kernel_size=4, stride=2, padding=2)
        self.conv4_2 = nn.Conv2d(16, 16, kernel_size=4, stride=2, padding=2)

        # Raw gate value; forward() passes it through a sigmoid, so the initial
        # mixing weight on the CNN logits is sigmoid(0.5) ~= 0.62.
        self.fusion_weight = nn.Parameter(torch.tensor([0.5]), requires_grad=True)

        self.fc_residual = nn.Sequential(
            nn.Linear(size, 26),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob),
            nn.Linear(26, 10),
        )
        # 288 = 2 branches * 16 channels * 3 * 3 spatial positions.
        self.fc = nn.Linear(288, 10)

    @staticmethod
    def _conv_branch(feature_map, convs, pool):
        """Apply ``conv -> ReLU -> pool`` for each conv in ``convs``, in order."""
        for conv in convs:
            feature_map = pool(F.relu(conv(feature_map)))
        return feature_map

    def forward(self, x):
        """Compute the fused class logits for a batch of index sequences.

        Args:
            x: Integer index tensor of shape ``(batch, 784)`` with values valid
                for the embedding table.

        Returns:
            A tuple ``(out, A)`` where ``out`` has shape ``(batch, 10)`` and
            ``A`` is the tanh-activated Gram matrix of the embeddings with a
            channel axis inserted, shape ``(batch, 1, 784, 784)``.
        """
        E = self.embedding(x)

        # Residual head: average each position's embedding vector down to one
        # scalar, giving a (batch, seq) tensor for the MLP.
        R = self.fc_residual(self.global_avg_pool(E).squeeze(-1))

        # Pairwise similarity between positions, and the same map after being
        # multiplied by T from the left, transposed, and multiplied again.
        A = torch.matmul(E, E.transpose(-2, -1))
        AA = torch.matmul(self.T, A)
        AA = torch.matmul(self.T, AA.transpose(-2, -1))

        # Squash to (-1, 1) and add the channel axis expected by Conv2d.
        A = self.activate(A).unsqueeze(1)
        AA = self.activate(AA).unsqueeze(1)

        feat_a = self._conv_branch(
            A,
            (self.conv1_1, self.conv2_1, self.conv3_1, self.conv4_1),
            self.pool_1,
        )
        feat_aa = self._conv_branch(
            AA,
            (self.conv1_2, self.conv2_2, self.conv3_2, self.conv4_2),
            self.pool_2,
        )

        # Stack both branches along the channel axis and flatten per sample.
        C = torch.cat((feat_a, feat_aa), dim=1).reshape(A.shape[0], -1)
        C = self.fc(C)

        # Learnable convex combination of the CNN logits and residual logits.
        gate = torch.sigmoid(self.fusion_weight)
        out = gate * C + (1 - gate) * R
        return out, A